import decimal
import os
from decimal import Decimal, getcontext, ROUND_HALF_UP
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql import SparkSession

# Pin the remote execution environment so the job does not accidentally pick
# up a different Spark/Python installation when multiple versions coexist.
os.environ['SPARK_HOME'] = '/export/server/spark'
os.environ["PYSPARK_PYTHON"] = "/root/anaconda3/bin/python"
os.environ["PYSPARK_DRIVER_PYTHON"] = "/root/anaconda3/bin/python"

# IDE shortcut: type "main" then Enter to expand this guard.
if __name__ == '__main__':
    # Grouped-aggregate pandas UDF (UDAF) computing the final life-table
    # row for a group: lx_d, dx_d, dx_ci.
    @F.pandas_udf('string')
    def udaf_3col(lx_d: pd.Series, qx_d: pd.Series, qx_ci: pd.Series) -> str:
        """Fold the life-table recursion over a group; return the last row.

        Row 0 seeds the accumulators directly from the input columns; every
        subsequent row i applies
            lx(i)    = lx(i-1) - dx(i-1) - dx_ci(i-1)
            dx(i)    = lx(i) * qx_d(i)
            dx_ci(i) = lx(i) * qx_ci(i)
        with each step quantized to 12 decimal places.

        :param lx_d:  survivors column (only row 0 is read).
        :param qx_d:  death-rate column (row 0 is read as an absolute count —
                      presumably the seed dx value; TODO confirm with caller).
        :param qx_ci: critical-illness column (row 0 likewise read as a count).
        :return: the final row as the string 'lx_d,dx_d,dx_ci'.
        """
        twelve_dp = decimal.Decimal('0.000000000000')

        def _dec(v) -> decimal.Decimal:
            # Convert via str so float inputs keep their printed value rather
            # than the exact binary expansion (the Decimal(0.1) pitfall).
            return v if isinstance(v, decimal.Decimal) else decimal.Decimal(str(v))

        tmp_lx_d = decimal.Decimal(0)
        tmp_dx_d = decimal.Decimal(0)
        tmp_dx_ci = decimal.Decimal(0)

        for i in range(len(lx_d)):
            if i == 0:
                # Seed the accumulators from the first row of the group.
                tmp_lx_d = _dec(lx_d.iloc[i])
                tmp_dx_d = _dec(qx_d.iloc[i])
                tmp_dx_ci = _dec(qx_ci.iloc[i])
            else:
                tmp_lx_d = (tmp_lx_d - tmp_dx_d - tmp_dx_ci).quantize(twelve_dp)
                # FIX: the original multiplied a Decimal by the raw series
                # value (TypeError for float/str inputs) — convert the rate
                # to Decimal first, matching the i == 0 branch.
                tmp_dx_d = (tmp_lx_d * _dec(qx_d.iloc[i])).quantize(twelve_dp)
                tmp_dx_ci = (tmp_lx_d * _dec(qx_ci.iloc[i])).quantize(twelve_dp)

        # NOTE(review): ROUND_HALF_UP is imported at module level but never
        # applied — quantize() uses the context default (ROUND_HALF_EVEN).
        # Confirm which rounding mode the actuarial spec requires.
        return str(tmp_lx_d) + ',' + str(tmp_dx_d) + ',' + str(tmp_dx_ci)


    # FIX: `spark` was referenced without ever being defined (NameError at
    # runtime) — create or reuse the SparkSession before registering the
    # UDAF so it can be called from Spark SQL as udaf_3col(...).
    spark = SparkSession.builder.appName('udaf_3col').getOrCreate()
    spark.udf.register('udaf_3col', udaf_3col)
